/*
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Copyright © 2023 Keith Packard
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above
 *    copyright notice, this list of conditions and the following
 *    disclaimer in the documentation and/or other materials provided
 *    with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived
 *    from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
 * COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
 * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
 * OF THE POSSIBILITY OF SUCH DAMAGE.
 */

/*
 * List of all IEEE rounding modes. round() switches on this value to
 * decide whether the truncated significand is incremented and whether
 * out-of-range magnitudes saturate or become infinite.
 */
typedef enum {
	TONEAREST, UPWARD, DOWNWARD, TOWARDZERO
} rounding_mode_t;

/*
 * IEEE-style numbers with explicit sign (for signed 0), along with
 * NaN and infinity values
 */

/*
 * Magnitude/category of a value: either an ordinary number ('num'),
 * or one of the two valueless markers 'nan' and 'inf'.
 */
typedef union {
	real	num;
	void	nan;
	void	inf;
} value_t;

/*
 * A value together with its sign. 'sign' is true for negative values;
 * it is stored separately so that zero, NaN and infinity can carry a
 * sign too. For 'num' values the number in 'u' is itself signed (see
 * make_float).
 */
typedef struct {
	bool	sign;
	value_t	u;
} float_t;

/*
 * Wrap a plain number in a float_t.
 *
 * The sign flag is set when 'num' is strictly negative; the number
 * itself is stored unchanged (still signed) in the 'num' member.
 */
float_t
make_float(real num)
{
	bool	negative = num < 0;

	return (float_t) { .sign = negative, .u = { .num = num } };
}

/*
 * Description of a floating point format.
 *
 * bits, min_exp and max_exp drive round(): 'bits' is the number of
 * significand bits kept, values whose exponent falls below 'min_exp'
 * lose precision (denorm), and values whose exponent exceeds
 * 'max_exp' overflow. 'exp_bits' is recorded by the initializers but
 * not read by the code in this file.
 *
 * first_exp and last_exp are filled in by generate() and bound the
 * range of exponents it iterates over (see next_exp).
 */
typedef struct {
	int	bits;
	int	exp_bits;
	int	min_exp;
	int	max_exp;

	int	first_exp;
	int	last_exp;
} format_t;

/*
 * Optional format: either a format_t or nothing. Used for the
 * intermediate format parameter of round() and generate().
 */
typedef union {
	format_t	format;
	void		none;
} format_or_none_t;

/* The "no intermediate format" value */
format_or_none_t none_format = { .none = <> };

/*
 * Round the given float to the specified format under the given
 * rounding mode.
 *
 *   f         value to round; NaN and infinity are returned unchanged
 *   format    destination format (significand bits, exponent range)
 *   i_format  optional intermediate format; when present, f is first
 *             rounded to that format with the same rounding mode and
 *             the result is then rounded to 'format'
 *   rm        rounding mode
 *
 * Returns the rounded value. Magnitudes too large for the format
 * become infinity, except in the modes/directions that round toward
 * zero, where they saturate to the largest finite magnitude.
 */

float_t
round(float_t f, format_t format, format_or_none_t i_format, rounding_mode_t rm)
{
	/* Optional first rounding step through the intermediate format */
	union switch (i_format) {
	case format wide:
		f = round(f, wide, none_format, rm);
		break;
	case none:
		break;
	}

	int bits = format.bits;

	union switch(f.u) {
	case num x:
		/* Work on the magnitude; the sign is re-applied at the end */
		if (f.sign)
			x = -x;
		int exp;
		if (x == 0)
			exp = 0;
		else
			exp = ceil(log2(x));
		int denorm = format.min_exp - exp;

		/* Denorm means our available precision is reduced */
		if (denorm > 0) {
#			printf("denorm %d\n", denorm);
			bits -= denorm;
		}

		/*
		 * True when the magnitude is at or beyond 2**max_exp,
		 * i.e. larger than every finite value of the format.
		 * This is tested on the value itself rather than on
		 * 'exp' so that a magnitude of exactly 2**max_exp
		 * (for which exp == max_exp) also saturates in the
		 * round-toward-zero directions instead of slipping
		 * through and overflowing to infinity below.
		 */
		bool too_large = x >= 2 ** (format.max_exp);

		/*
		 * Compute the significand. This could use the
		 * mantissa built-in function if 'x' was a float, but
		 * we usually want to use rationals instead to
		 * preserve all of the bits until rounding happens
		 */
		real mant = abs(x / (2**(exp-bits)));

		/*
		 * Split into integer and fractional portions. The
		 * integer portion holds the number of bits in the
		 * eventual result, the fractional portion is used in
		 * rounding decisions.
		 */
		int ipart = floor(mant);
		real fpart = mant - ipart;

#		printf("%a: mant %e ipart %d fpart %f\n", x, mant, ipart, fpart);

		union switch(rm) {
		case TONEAREST:
			if (fpart == 0.5) {
				/* round even when the fraction is exactly 1/2 */
				if ((ipart & 1) != 0)
					ipart = ipart + 1;
			} else if (fpart > 0.5) {
				ipart = ipart + 1;
			}
			break;
		case UPWARD:
			if (!f.sign) {
				if (fpart > 0)
					ipart = ipart + 1;
			} else {
				/*
				 * Large negative values round
				 * up to the negative finite value
				 * of greatest magnitude instead of
				 * rounding down to -infinity
				 */
				if (too_large) {
					exp = format.max_exp;
					ipart = (2**bits) - 1;
				}
			}
			break;
		case DOWNWARD:
			if (f.sign) {
				if (fpart > 0)
					ipart = ipart + 1;
			} else {
				/*
				 * Large positive values round
				 * down to the positive finite value
				 * of greatest magnitude instead of
				 * rounding up to infinity
				 */
				if (too_large) {
					exp = format.max_exp;
					ipart = (2**bits) - 1;
				}
			}
			break;
		case TOWARDZERO:
			/*
			 * Large magnitude values round to the value
			 * of largest magnitude of the appropriate
			 * sign instead of away from zero to
			 * +/-infinity.
			 */
			if (too_large) {
				exp = format.max_exp;
				ipart = (2**bits) - 1;
			}
			break;
		}

		/*
		 * Handle underflow in a way that preserves rounding
		 * to a value of smallest magnitude.
		 */
		if (bits < 0) {
			exp -= bits;
			bits = 0;
		}

#		printf("rounded ipart %d exp %d bits %d\n", ipart, exp, bits);

		/*
		 * Compute the final significand, which
		 * is always >= 0.5 and < 1
		 */
		mant = ipart / (2 ** bits);
		if (mant >= 1) {
			/* Rounding carried out of the top bit; renormalize */
			exp++;
			mant /= 2;
		}

		/* Overflow to infinity */
		if (exp > format.max_exp) {
			f.u.inf = <>;
		} else {
			f.u.num = mant * 2 ** exp;
			if (f.sign)
				f.u.num = -f.u.num;
		}
		break;
	case nan:
	case inf:
		/* Nothing to round */
		break;
	}
	return f;
}

/*
 * Format a float_t as text followed by 'suffix'.
 *
 * Numbers are printed with %a, except zero which is printed with
 * %.1f. NaN and infinity are printed as fixed-width padded names,
 * with a leading '-' when the sign flag is set.
 */
string
strfromfloat(float_t f, string suffix)
{
	string	text;

	union switch (f.u) {
	case num x:
		/*
		 * Zero is printed as 0.0 so that the suffix works.
		 * Everything else uses %a, which involves conversion
		 * to float; the default has 256 bits of significand,
		 * which is expected to be sufficient for any ieee
		 * target.
		 */
		text = sprintf((x == 0) ? "%.1f%s" : "%a%s", x, suffix);
		break;
	case nan:
		text = sprintf("%s%s", f.sign ? "          -nan" : "           nan", suffix);
		break;
	case inf:
		text = sprintf("%s%s", f.sign ? "          -inf" : "           inf", suffix);
		break;
	}
	return text;
}

/* True when f holds an ordinary number (neither NaN nor infinity) */
bool
isfinite(float_t f)
{
	bool	result = false;

	union switch (f.u) {
	case num:
		result = true;
		break;
	default:
		break;
	}
	return result;
}

/* True when f holds the NaN marker */
bool
isnan(float_t f)
{
	bool	result = false;

	union switch (f.u) {
	case nan:
		result = true;
		break;
	default:
		break;
	}
	return result;
}

/* True when f holds the infinity marker (of either sign) */
bool
isinf(float_t f)
{
	bool	result = false;

	union switch (f.u) {
	case inf:
		result = true;
		break;
	default:
		break;
	}
	return result;
}

/*
 * Exact product of two values.
 *
 * A NaN operand is returned as-is (the first operand takes
 * priority). Otherwise the sign of the result is negative exactly
 * when the operand signs differ. No rounding is performed here.
 */
float_t
times(float_t a, float_t b)
{
	if (isnan(a))
		return a;
	if (isnan(b))
		return b;

	bool negative = a.sign != b.sign;

	/* Result used for inf * 0, which has no defined value */
	float_t undefined = (float_t) { .sign = negative, .u = { .nan = <> } };

	/* Any other product involving an infinity is an infinity */
	if (isinf(a)) {
		if (b.u == (value_t.num) 0)
			return undefined;
		return (float_t) { .sign = negative, .u = a.u };
	}
	if (isinf(b)) {
		if (a.u == (value_t.num) 0)
			return undefined;
		return (float_t) { .sign = negative, .u = b.u };
	}

	/* Both operands are ordinary numbers */
	real product = a.u.num * b.u.num;
	return (float_t) { .sign = negative, .u = { .num = product } };
}

/*
 * Exact sum of two values.
 *
 * A NaN operand is returned as-is (the first operand takes
 * priority). Infinities of opposite sign produce a NaN with the sign
 * flag set; otherwise an infinite operand is returned unchanged. An
 * exactly-zero sum takes the sign of 'a'. No rounding is performed
 * here.
 */
float_t
plus(float_t a, float_t b)
{
	if (isnan(a))
		return a;
	if (isnan(b))
		return b;

	bool a_inf = isinf(a);
	bool b_inf = isinf(b);

	/* inf + -inf is NaN */
	if (a_inf && b_inf && a.sign != b.sign)
		return (float_t) { .sign = true, .u = { .nan = <> } };
	if (a_inf)
		return a;
	if (b_inf)
		return b;

	/* Both operands are ordinary numbers */
	real sum = a.u.num + b.u.num;
	bool negative = (sum == 0) ? a.sign : (sum < 0);
	return (float_t) { .sign = negative, .u = { .num = sum } };
}

/*
 * With the support functions above in place, fused multiply-add is
 * just the exact product of x and y added exactly to z. The caller
 * performs the single rounding step afterwards.
 */
float_t
fma(float_t x, float_t y, float_t z)
{
	float_t product = times(x, y);

	return plus(product, z);
}

/*
 * Step to the next exponent to test after 'e'.
 *
 * Exponents normally advance by one, but three jumps skip over the
 * ranges in between: from just after first_exp to just below
 * min_exp, from min_exp to -1, and from 1 to just below last_exp.
 * The checks are made in this order and the first match wins.
 */
int next_exp(int e, format_t format)
{
	if (e == format.first_exp + 1)
		return format.min_exp - 2;
	if (e == format.min_exp)
		return -1;
	if (e == 1)
		return format.last_exp - 2;
	return e + 1;
}

/*
 * Usual exponent range for the specified number of bits in the
 * exponent. Note that the minimum is off by one for the 80-bit m68k
 * format, which uses a slightly different form for denorm.
 */

/* Smallest normal exponent: 3 - 2**(exp_bits-1), e.g. -125 for 8 bits */
int
min_exp(int exp_bits)
{
	return 3 - 2 ** (exp_bits - 1);
}

/* Largest exponent: 2**(exp_bits-1), e.g. 128 for 8 bits */
int
max_exp(int exp_bits)
{
	int half_bits = exp_bits - 1;

	return 2 ** half_bits;
}

/*
 * Print a set of fma test vectors for the specified floating point
 * format.
 *
 *   suf       suffix appended to every printed constant
 *   format    format of the operands and of the final result
 *   i_format  optional intermediate format the result is rounded
 *             through before being rounded to 'format'
 *
 * Each output line holds an index comment, the three operands and
 * the result under the four rounding modes, in the order TONEAREST,
 * UPWARD, DOWNWARD, TOWARDZERO.
 */
void generate(string suf, format_t format, format_or_none_t i_format)
{
	int bits = format.bits;

	/* Exponent range walked by the loops below (see next_exp) */
	format.first_exp = format.min_exp - bits - 2;
	format.last_exp = format.max_exp;

	/* Significand with both the top and the bottom bit set */
	real val = 1 + 2**-(bits-1);

	int index = 0;

	/* Check +/- z */
	for (int z_sign = -1; z_sign <= 1; z_sign += 2) {
		for (int z_exp = format.first_exp; z_exp <= format.last_exp; z_exp = next_exp(z_exp, format)) {
			float_t z = round(make_float(z_sign * val * (2 ** z_exp)), format, none_format, rounding_mode_t.TONEAREST);

			/* y is always positive */
			for (int y_exp = format.first_exp; y_exp <= format.last_exp; y_exp = next_exp(y_exp, format)) {
				float_t y = round(make_float(val * (2 ** y_exp)), format, none_format, rounding_mode_t.TONEAREST);

				/* Check +/- x */
				for (int x_sign = -1; x_sign <= 1; x_sign += 2) {
					for (int x_exp = format.first_exp; x_exp <= format.last_exp; x_exp = next_exp(x_exp, format)) {
						float_t x = round(make_float(x_sign * val * (2 ** x_exp)), format, none_format, rounding_mode_t.TONEAREST);

						/* Exact result, rounded once per mode */
						float_t exact = fma(x, y, z);
						float_t nearest = round(exact, format, i_format, rounding_mode_t.TONEAREST);
						float_t upward = round(exact, format, i_format, rounding_mode_t.UPWARD);
						float_t downward = round(exact, format, i_format, rounding_mode_t.DOWNWARD);
						float_t towardzero = round(exact, format, i_format, rounding_mode_t.TOWARDZERO);

						printf(" /* %4d */ { %-17s, %-17s, %-17s, {", index, strfromfloat(x, suf), strfromfloat(y, suf), strfromfloat(z, suf));
						printf(" %s,", strfromfloat(nearest, suf));
						printf(" %s,", strfromfloat(upward, suf));
						printf(" %s,", strfromfloat(downward, suf));
						printf(" %s" , strfromfloat(towardzero, suf));
						printf(" } },\n");
						index += 1;
					}
				}
			}
		}
	}
}

/*
 * Format descriptions used by main(). first_exp and last_exp are left
 * unset here; generate() fills them in on its own copy.
 */

/* 24-bit significand, 8-bit exponent */
format_t ieee_32 = {
	.bits = 24,
	.exp_bits = 8,
	.min_exp = min_exp(8),
	.max_exp = max_exp(8),
};

/* 53-bit significand, 11-bit exponent */
format_t ieee_64 = {
	.bits = 53,
	.exp_bits = 11,
	.min_exp = min_exp(11),
	.max_exp = max_exp(11),
};

/* 113-bit significand, 15-bit exponent */
format_t ieee_128 = {
	.bits = 113,
	.exp_bits = 15,
	.min_exp = min_exp(15),
	.max_exp = max_exp(15),
};

/* 64-bit significand, 15-bit exponent, usual minimum exponent (-16381) */
format_t intel_80 = {
	.bits = 64,
	.exp_bits = 15,
	.min_exp = min_exp(15),
	.max_exp = max_exp(15),
};

/*
 * Same as intel_80 except for the minimum exponent, which is one
 * lower (-16382 instead of -16381).
 */
format_t moto_80 = {
	.bits = 64,
	.exp_bits = 15,
	.min_exp = -16382,
	.max_exp = max_exp(15),
};

/* The two 80-bit formats wrapped for use as intermediate formats */
format_or_none_t intel_80_optional = { .format = intel_80 };
format_or_none_t moto_80_optional = { .format = moto_80 };

/*
 * Print one conditionally-compiled table of test vectors.
 *
 *   guard     complete "#if ..." line, including the newline
 *   define    complete "#define ..." line, including the newline
 *   table     complete line opening the array definition
 *   suf, format, i_format
 *             passed through to generate()
 */
void emit_section(string guard, string define, string table, string suf, format_t format, format_or_none_t i_format)
{
	printf("%s", guard);
	printf("%s", define);
	printf("%s", table);
	generate(suf, format, i_format);
	printf("};\n");
	printf("#endif\n");
}

/*
 * Print the whole generated file: a header comment followed by one
 * section per combination of result format and intermediate format,
 * separated by blank lines.
 */
void main()
{
	printf("/* This file is automatically generated with fma_gen.5c */\n");
	printf("\n");

	/* float, evaluated directly */
	emit_section("#if __FLT_EVAL_METHOD__ == 0 && FLT_MANT_DIG == 24\n",
		     "#define HAVE_FLOAT_FMA_VEC\n",
		     "TEST_CONST struct fmaf_vec fmaf_vec[] = {\n",
		     "f", ieee_32, none_format);
	printf("\n");

	/* float, evaluated through each 80-bit format */
	emit_section("#if __FLT_EVAL_METHOD__ == 2 && FLT_MANT_DIG == 24 && LDBL_MANT_DIG == 64 && LDBL_MIN_EXP == -16381\n",
		     "#define HAVE_FLOAT_FMA_VEC\n",
		     "TEST_CONST struct fmaf_vec fmaf_vec[] = {\n",
		     "f", ieee_32, intel_80_optional);
	printf("\n");

	emit_section("#if __FLT_EVAL_METHOD__ == 2 && FLT_MANT_DIG == 24 && LDBL_MANT_DIG == 64 && LDBL_MIN_EXP == -16382\n",
		     "#define HAVE_FLOAT_FMA_VEC\n",
		     "TEST_CONST struct fmaf_vec fmaf_vec[] = {\n",
		     "f", ieee_32, moto_80_optional);
	printf("\n");

	/* double, evaluated directly */
	emit_section("#if __FLT_EVAL_METHOD__ <= 1 && DBL_MANT_DIG == 53\n",
		     "#define HAVE_DOUBLE_FMA_VEC\n",
		     "TEST_CONST struct fma_vec fma_vec[] = {\n",
		     "", ieee_64, none_format);
	printf("\n");

	/* double, evaluated through each 80-bit format */
	emit_section("#if __FLT_EVAL_METHOD__ == 2 && DBL_MANT_DIG == 53 && LDBL_MANT_DIG == 64 && LDBL_MIN_EXP == -16381\n",
		     "#define HAVE_DOUBLE_FMA_VEC\n",
		     "TEST_CONST struct fma_vec fma_vec[] = {\n",
		     "", ieee_64, intel_80_optional);
	printf("\n");

	emit_section("#if __FLT_EVAL_METHOD__ == 2 && DBL_MANT_DIG == 53 && LDBL_MANT_DIG == 64 && LDBL_MIN_EXP == -16382\n",
		     "#define HAVE_DOUBLE_FMA_VEC\n",
		     "TEST_CONST struct fma_vec fma_vec[] = {\n",
		     "", ieee_64, moto_80_optional);
	printf("\n");

	/* long double in each of its supported shapes */
	emit_section("#if LDBL_MANT_DIG == 64 && LDBL_MIN_EXP == -16381\n",
		     "#define HAVE_LONG_DOUBLE_FMA_VEC\n",
		     "TEST_CONST struct fmal_vec fmal_vec[] = {\n",
		     "l", intel_80, none_format);
	printf("\n");

	emit_section("#if LDBL_MANT_DIG == 64 && LDBL_MIN_EXP == -16382\n",
		     "#define HAVE_LONG_DOUBLE_FMA_VEC\n",
		     "TEST_CONST struct fmal_vec fmal_vec[] = {\n",
		     "l", moto_80, none_format);
	printf("\n");

	/* Last section: no trailing blank line */
	emit_section("#if LDBL_MANT_DIG == 113\n",
		     "#define HAVE_LONG_DOUBLE_FMA_VEC\n",
		     "TEST_CONST struct fmal_vec fmal_vec[] = {\n",
		     "l", ieee_128, none_format);
}

main();
